import json
import logging

from asgiref.sync import async_to_sync
from channels.generic.websocket import (
    AsyncConsumer,
    AsyncJsonWebsocketConsumer,
    AsyncWebsocketConsumer,
    WebsocketConsumer,
)
from rest_framework_jwt.authentication import jwt_decode_handler


class ChatConsumer(AsyncJsonWebsocketConsumer):
    """WebSocket consumer for a chat room named by the ``group_name`` URL kwarg.

    Each connection joins the channel-layer group named after the room. It is
    also tracked in the class-level ``chats`` registry, so that
    ``receive_json`` can tell whether anyone else is currently in the room.
    """

    # Maps room group name -> set of ChatConsumer instances connected to it.
    # NOTE(review): this registry lives in process memory. With several worker
    # processes, each process only sees its own connections. Confirm the
    # deployment runs a single process, or move this state to a shared store.
    chats = {}

    _logger = logging.getLogger(__name__)

    async def connect(self):
        """Join the room's group, register this connection, then accept it."""
        self.group_name = self.scope['url_route']['kwargs']['group_name']
        await self.channel_layer.group_add(self.group_name, self.channel_name)
        # setdefault replaces a bare try/except that also swallowed unrelated
        # errors.
        ChatConsumer.chats.setdefault(self.group_name, set()).add(self)
        await self.accept()

    async def disconnect(self, close_code):
        """Leave the room's group and drop this connection from ``chats``.

        Channels calls this hook after the socket is already closed, so no
        explicit ``close()`` is sent here. Sending ``websocket.close`` on a
        closed connection is an error under some ASGI servers.
        """
        await self.channel_layer.group_discard(self.group_name, self.channel_name)
        members = ChatConsumer.chats.get(self.group_name)
        if members is not None:
            members.discard(self)
            # Remove empty rooms so the registry does not grow without bound.
            if not members:
                del ChatConsumer.chats[self.group_name]

    async def receive_json(self, message, **kwargs):
        """Route a JSON message received from the client.

        If more than one connection is registered for this room, the message
        is broadcast to the room group as a ``chat.message`` event, which
        ``chat_message`` handles. Otherwise it is sent as a ``push.message``
        event to the group named ``str(message['fid'])``.

        Args:
            message: Decoded JSON payload. It must be an object. For the push
                path it must carry an ``fid`` key, presumably the recipient's
                identifier (confirm against the client).
        """
        self._logger.debug("chat message received: %r", message)
        if not isinstance(message, dict):
            # receive_json decodes any JSON value; only objects can be routed.
            return

        online = len(ChatConsumer.chats.get(self.group_name, ()))
        if online > 1:
            await self.channel_layer.group_send(
                self.group_name,
                {
                    "type": "chat.message",
                    "message": message,
                },
            )
            return

        to_user = message.get('fid')
        if to_user is None:
            # Without a recipient there is nothing to push to. str(None) would
            # otherwise target a group literally named "None".
            return
        await self.channel_layer.group_send(
            str(to_user),
            {
                "type": "push.message",
                "event": message,
            },
        )

    async def chat_message(self, event):
        """Forward a ``chat.message`` group event to this client."""
        await self.send_json({
            "message": event["message"],
        })


class PushConsumer(AsyncWebsocketConsumer):
    """Per-user WebSocket that relays ``push.message`` events to the client.

    The connection subscribes to a channel-layer group named after the
    ``username`` URL kwarg. Any ``push.message`` event sent to that group is
    written to the socket as the JSON text ``{"event": ...}``.

    NOTE(review): ChatConsumer.receive_json addresses this group as
    ``str(message['fid'])``. Presumably ``fid`` equals this ``username``;
    confirm.
    """

    async def connect(self):
        """Subscribe to the user's personal group and accept the socket."""
        route_kwargs = self.scope['url_route']['kwargs']
        self.group_name = route_kwargs['username']
        await self.channel_layer.group_add(self.group_name, self.channel_name)
        await self.accept()

    async def disconnect(self, close_code):
        """Unsubscribe from the user's personal group."""
        await self.channel_layer.group_discard(self.group_name, self.channel_name)

    async def push_message(self, event):
        """Send the ``event`` payload of a ``push.message`` to the client."""
        payload = {"event": event['event']}
        await self.send(text_data=json.dumps(payload))
